{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_11 import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Serializing the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=2920)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 128\n",
    "bs = 64\n",
    "\n",
    "tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]\n",
    "val_tfms = [make_rgb, CenterCrop(size), np_to_float]\n",
    "il = ImageList.from_files(path, tfms=tfms)\n",
    "sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))\n",
    "ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())\n",
    "ll.valid.x.tfms = val_tfms\n",
    "data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(il)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = LabelSmoothingCrossEntropy()\n",
    "opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):\n",
    "    phases = create_phases(pct_start)\n",
    "    sched_lr  = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))\n",
    "    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))\n",
    "    return [ParamScheduler('lr', sched_lr),\n",
    "            ParamScheduler('mom', sched_mom)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 3e-3\n",
    "pct_start = 0.5\n",
    "cbsched = sched_1cycle(lr, pct_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(40, cbsched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "st = learn.model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "type(st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "', '.join(st.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "st['10.bias']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mdl_path = path/'models'\n",
    "mdl_path.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's also possible to save the whole model, including the architecture, but it gets quite fiddly and we don't recommend it. Instead, just save the parameters, and recreate the model directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(st, mdl_path/'iw5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3127)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pets = datasets.untar_data(datasets.URLs.PETS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pets.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pets_path = pets/'images'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "il = ImageList.from_files(pets_path, tfms=tfms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "il"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def random_splitter(fn, p_valid): return random.random() < p_valid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = il.items[0].name; n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "re.findall(r'^(.*)_\\d+.jpg$', n)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pet_labeler(fn): return re.findall(r'^(.*)_\\d+.jpg$', fn.name)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc = CategoryProcessor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ll = label_by_func(sd, pet_labeler, proc_y=proc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "', '.join(proc.vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ll.valid.x.tfms = val_tfms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_out = len(proc.vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(5, cbsched)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3265)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "st = torch.load(mdl_path/'iw5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = learn.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.load_state_dict(st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))\n",
    "m_cut = m[:cut]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xb,yb = get_batch(data.valid_dl, learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = m_cut(xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ni = pred.shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AdaptiveConcatPool2d(nn.Module):\n",
    "    def __init__(self, sz=1):\n",
    "        super().__init__()\n",
    "        self.output_size = sz\n",
    "        self.ap = nn.AdaptiveAvgPool2d(sz)\n",
    "        self.mp = nn.AdaptiveMaxPool2d(sz)\n",
    "    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nh = 40\n",
    "\n",
    "m_new = nn.Sequential(\n",
    "    m_cut, AdaptiveConcatPool2d(), Flatten(),\n",
    "    nn.Linear(ni*2, data.c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.model = m_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(5, cbsched)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## adapt_model and gradual unfreezing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3483)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adapt_model(learn, data):\n",
    "    cut = next(i for i,o in enumerate(learn.model.children())\n",
    "               if isinstance(o,nn.AdaptiveAvgPool2d))\n",
    "    m_cut = learn.model[:cut]\n",
    "    xb,yb = get_batch(data.valid_dl, learn)\n",
    "    pred = m_cut(xb)\n",
    "    ni = pred.shape[1]\n",
    "    m_new = nn.Sequential(\n",
    "        m_cut, AdaptiveConcatPool2d(), Flatten(),\n",
    "        nn.Linear(ni*2, data.c_out))\n",
    "    learn.model = m_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)\n",
    "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adapt_model(learn, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for p in learn.model[0].parameters(): p.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(3, sched_1cycle(1e-2, 0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for p in learn.model[0].parameters(): p.requires_grad_(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(5, cbsched, reset_opt=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch norm transfer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3567)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)\n",
    "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))\n",
    "adapt_model(learn, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_mod(m, f):\n",
    "    f(m)\n",
    "    for l in m.children(): apply_mod(l, f)\n",
    "\n",
    "def set_grad(m, b):\n",
    "    if isinstance(m, (nn.Linear,nn.BatchNorm2d)): return\n",
    "    if hasattr(m, 'weight'):\n",
    "        for p in m.parameters(): p.requires_grad_(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_mod(learn.model, partial(set_grad, b=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(3, sched_1cycle(1e-2, 0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_mod(learn.model, partial(set_grad, b=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(5, cbsched, reset_opt=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pytorch already has an `apply` method we can use:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.model.apply(partial(set_grad, b=False));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Discriminative LR and param groups"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3799)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))\n",
    "adapt_model(learn, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bn_splitter(m):\n",
    "    def _bn_splitter(l, g1, g2):\n",
    "        if isinstance(l, nn.BatchNorm2d): g2 += l.parameters()\n",
    "        elif hasattr(l, 'weight'): g1 += l.parameters()\n",
    "        for ll in l.children(): _bn_splitter(ll, g1, g2)\n",
    "        \n",
    "    g1,g2 = [],[]\n",
    "    _bn_splitter(m[0], g1, g2)\n",
    "    \n",
    "    g2 += m[1:].parameters()\n",
    "    return g1,g2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a,b = bn_splitter(learn.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(len(a)+len(b), len(list(m.parameters())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Learner.ALL_CBS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from types import SimpleNamespace\n",
    "cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cb_types.after_backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class DebugCallback(Callback):\n",
    "    _order = 999\n",
    "    def __init__(self, cb_name, f=None): self.cb_name,self.f = cb_name,f\n",
    "    def __call__(self, cb_name):\n",
    "        if cb_name==self.cb_name:\n",
    "            if self.f: self.f(self.run)\n",
    "            else:      set_trace()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):\n",
    "    phases = create_phases(pct_start)\n",
    "    sched_lr  = [combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))\n",
    "                 for lr in lrs]\n",
    "    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))\n",
    "    return [ParamScheduler('lr', sched_lr),\n",
    "            ParamScheduler('mom', sched_mom)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "disc_lr_sched = sched_1cycle([0,3e-2], 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(xresnet18, data, loss_func, opt_func,\n",
    "                    c_out=10, norm=norm_imagenette, splitter=bn_splitter)\n",
    "\n",
    "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))\n",
    "adapt_model(learn, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _print_det(o): \n",
    "    print (len(o.opt.param_groups), o.opt.hypers)\n",
    "    raise CancelTrainException()\n",
    "\n",
    "learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(3, disc_lr_sched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(5, disc_lr_sched)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!./notebook2script.py 11a_transfer_learning.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
